Petals to the Metal
Getting Started with TPUs on Kaggle
- TFRecord basics
- Hyperparameters
- Data
- Model
- Optimizer
- Callbacks
- Training
- Visual Validation
- 1% Better Every Day
%%writefile conditional_cell_extension.py
def run_if(line, cell=None):
    '''Execute the current cell only if ``line`` evaluates to True.

    Used as ``%%run_if {FLAG}``: IPython expands ``{FLAG}`` to its value before
    calling the magic, so ``line`` is usually the text ``True`` or ``False``.

    Args:
        line: Python expression, evaluated in the user's namespace.
        cell: body of the cell. It is ``None`` when the magic is used in its
            line form (``%run_if ...``); then there is nothing to run.
    '''
    if cell is None:
        # Line-magic form: no cell body to execute.
        return
    shell = get_ipython()
    # NOTE: ``line`` is evaluated as code; only notebook-authored input belongs here.
    if not shell.ev(line):
        return
    # Run the body through IPython's input transformers first so it may use
    # IPython-only syntax (e.g. ``!pip install ...``), which plain ``ex``
    # would reject with a SyntaxError.
    shell.ex(shell.input_transformer_manager.transform_cell(cell))
def load_ipython_extension(shell):
    '''IPython extension hook: make ``run_if`` available as line and cell magic.'''
    shell.register_magic_function(run_if, magic_kind='line_cell')
def unload_ipython_extension(shell):
    '''Unregister the run_if magic when the extension unloads.

    ``load_ipython_extension`` registers run_if with kind 'line_cell', which
    records it in both the 'line' and the 'cell' magic tables, so it is
    removed from both. ``pop(..., None)`` keeps a repeated unload harmless.
    '''
    magic_tables = shell.magics_manager.magics
    for kind in ('line', 'cell'):
        magic_tables[kind].pop('run_if', None)
%reload_ext conditional_cell_extension
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re
import warnings, math, sys, json, pprint, pdb
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score,precision_score, recall_score, confusion_matrix
# Hide Python warnings to keep notebook output readable, then report the TF version.
warnings.simplefilter('ignore')
print("Using TensorFlow v{}".format(tf.__version__))
def seed_everything(seed=0):
    """Seed Python's ``random``, NumPy and TensorFlow with ``seed``.

    Also exports ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS=1`` to the
    environment. ``PYTHONHASHSEED`` only affects interpreters started after
    this call; the hash seed of the running process is fixed at startup.
    """
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)
    os.environ.update(PYTHONHASHSEED=str(seed), TF_DETERMINISTIC_OPS='1')
# Detect the runtime: Colab's IPython shell repr mentions 'google.colab'.
GOOGLE = 'google.colab' in str(get_ipython())
KAGGLE = not GOOGLE

# tf.data tuning constant for the .map / .prefetch / num_parallel_reads calls
# below; it must exist before the first Dataset.map.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Distribution strategy used for the global batch size, LR scaling and model
# creation. Connect to a TPU when one is reachable; otherwise fall back to the
# default single-replica strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()
print(f"Number of replicas: {strategy.num_replicas_in_sync}")

# True: build the EfficientNet backbone with ImageNet weights;
# False: build it without pretrained weights.
TRAIN = True

# Data lives under a Google Drive mirror on Colab, and under /kaggle on Kaggle.
project_name = 'tpu-getting-started'
root_path = '/content/gdrive/MyDrive/' if GOOGLE else '/'
input_path = f'{root_path}kaggle/input/{project_name}/'
working_path = f'{input_path}working/' if GOOGLE else '/kaggle/working/'
os.makedirs(working_path, exist_ok=True)
os.chdir(working_path)
os.listdir(input_path)
# Raw flower JPEGs; the parent folder of each file is its class name.
GCS_PATTERN = 'gs://flowers-public/*/*.jpg'
# Path prefix of the TFRecord shards used in this walkthrough.
GCS_OUTPUT = 'gs://flowers-public/tfrecords-jpeg-192x192-2/flowers'
SHARDS = 16                # number of TFRecord files to split the images into
TARGET_SIZE = [192, 192]   # (height, width) every image is resized/cropped to
# Class names as bytestrings, matching the folder names in GCS_PATTERN.
CLASSES = [b'daisy', b'dandelion', b'roses', b'sunflowers', b'tulips']
def decode_image_and_label(filename):
    """Read one JPEG and derive its one-hot label from the parent folder name.

    Args:
        filename: scalar string tensor, e.g. ``gs://flowers-public/roses/x.jpg``.

    Returns:
        image: uint8 tensor [height, width, 3].
        label: int8 one-hot vector with one entry per item of CLASSES.
    """
    bits = tf.io.read_file(filename)
    # Force 3 channels: without it a grayscale JPEG decodes to 1 channel and
    # later batching of mixed-channel images fails.
    image = tf.image.decode_jpeg(bits, channels=3)
    # Splitting a scalar string yields a dense 1-D tensor; [-2] is the folder.
    label = tf.strings.split(filename, sep='/')[-2]
    label = tf.cast(tf.equal(CLASSES, label), tf.int8)
    return image, label
# Dataset of every JPEG path matching GCS_PATTERN (fixed seed for the ordering); show three.
filenames = tf.data.Dataset.list_files(GCS_PATTERN, seed=16)
for path in filenames.take(3):
    print(path)
def show_images(ds):
_,axs = plt.subplots(3,3,figsize=(16,16))
for ((x, y), ax) in zip(ds.take(9), axs.flatten()):
ax.imshow(x.numpy().astype(np.uint8))
ax.set_title(np.argmax(y))
ax.axis('off')
# Decode every file into (image, one-hot label) in parallel and preview it.
ds0 = filenames.map(
    decode_image_and_label,
    num_parallel_calls=AUTOTUNE,
)
show_images(ds0)
def resize_and_crop_image(image, label):
    """Resize ``image`` so it fills TARGET_SIZE, then centre-crop the overflow.

    "Fill" algorithm: scale by the smallest factor (aspect ratio preserved)
    that makes both sides at least as large as the target, then cut the
    centred TARGET_SIZE window out. The result has no black bars.

    Args:
        image: tensor [height, width, channels].
        label: passed through unchanged.

    Returns:
        (image, label), with image float32 of shape [*TARGET_SIZE, channels].
    """
    th, tw = TARGET_SIZE                 # target height, target width
    h = tf.shape(image)[0]
    w = tf.shape(image)[1]
    # int / int32-tensor is true division (float64).
    scale = tf.maximum(th / h, tw / w)
    # ceil keeps both sides >= target despite float rounding; tf.image.resize
    # needs integer sizes.
    nh = tf.cast(tf.math.ceil(tf.cast(h, tf.float64) * scale), tf.int32)
    nw = tf.cast(tf.math.ceil(tf.cast(w, tf.float64) * scale), tf.int32)
    image = tf.image.resize(image, [nh, nw])
    # Offsets and size are given as (height, width), matching the tensor
    # layout, so a non-square TARGET_SIZE is not transposed.
    image = tf.image.crop_to_bounding_box(
        image, (nh - th) // 2, (nw - tw) // 2, th, tw)
    return image, label
# Bring every image to TARGET_SIZE and preview the result.
ds1 = ds0.map(
    resize_and_crop_image,
    num_parallel_calls=AUTOTUNE,
)
show_images(ds1)
Speed test: too slow
Google Cloud Storage is capable of great throughput but has a per-file access penalty. Run the cell below and see that throughput is around 8 images per second. That is too slow. Training on thousands of individual files will not work. We have to use the TFRecord format to group files together.
%%time
# Pull 10 batches of 8 straight from individual JPEG files to gauge throughput.
for image, label in ds1.batch(8).take(10):
    class_ids = [np.argmax(lbl) for lbl in label.numpy()]
    print("Image batch shape {} {}".format(image.numpy().shape, class_ids))
def recompress_image(image, label):
    """Re-encode an image as JPEG bytes for storage in a TFRecord.

    Returns (jpeg_bytes, label, height, width). Height and width are read
    from the image before encoding.
    """
    dims = tf.shape(image)
    height, width = dims[0], dims[1]
    jpeg = tf.image.encode_jpeg(tf.cast(image, tf.uint8),
                                optimize_size=True,
                                chroma_downsampling=False)
    return jpeg, label, height, width
# Number of source JPEGs (despite the name, a count rather than a pixel size).
IMAGE_SIZE = len(tf.io.gfile.glob(GCS_PATTERN))
# Images per shard, rounded up so that SHARDS shards hold every image.
SHARD_SIZE = math.ceil(IMAGE_SIZE / SHARDS)
# One batch == one future TFRecord shard of JPEG-encoded images.
ds2 = (ds1
       .map(recompress_image, num_parallel_calls=AUTOTUNE)
       .batch(SHARD_SIZE))
Why TFRecords?
TPUs have eight cores which act as eight independent workers. We can get data to each core more efficiently by splitting the dataset into multiple files or shards. This way, each core can grab an independent part of the data as it needs.
The most convenient kind of file to use for sharding in TensorFlow is a TFRecord. A TFRecord is a binary file that contains sequences of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord.
The most convenient way of serializing data in TensorFlow is to wrap the data with `tf.Example`. This is a record format based on Google's protobufs but designed for TensorFlow. It's more or less like a dict with some type annotations.
# Round-trip a small uint8 tensor through its serialized byte-string form.
x = tf.constant([[1, 2], [3, 4]], dtype=tf.uint8)
x_in_bytes = tf.io.serialize_tensor(x)
for shown in (x, x_in_bytes, tf.io.parse_tensor(x_in_bytes, out_type=tf.uint8)):
    print(shown)
A TFRecord is a sequence of bytes, so we have to turn our data into byte-strings before it can go into a TFRecord. We can use `tf.io.serialize_tensor` to turn a tensor into a byte-string and `tf.io.parse_tensor` to turn it back. It's important to keep track of your tensor's datatype (in this case `tf.uint8`), since you have to specify it when parsing the string back into a tensor.
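Writing the TFRecords requires a `gs://` domain (GCS bucket) that you have write access to, so the writer below is left commented out.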
#def _bytestring_feature(list_of_bytestrings):
#    return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))
#
#def _int_feature(list_of_ints): # int64
#    return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))
#
#def _float_feature(list_of_floats): # float32
#    return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))
#
#def to_tfrecord(tfrec_filewriter, img_bytes, label, height, width):
#    # `label` is the int8 one-hot vector produced by decode_image_and_label.
#    class_num = int(np.argmax(label))
#    one_hot = np.eye(len(CLASSES))[class_num]
#    # Keys must match the schema parsed by read_tfrecord.
#    feature = {
#        "image": _bytestring_feature([img_bytes]),           # one image in the list
#        "class": _int_feature([class_num]),                  # one class in the list
#        "label": _bytestring_feature([CLASSES[class_num]]),  # fixed length (1) list of strings, the text label
#        "size": _int_feature([height, width]),               # fixed length (2) list of ints
#        "one_hot_class": _float_feature(one_hot.tolist())    # variable length list of floats, n=len(CLASSES)
#    }
#    return tf.train.Example(features=tf.train.Features(feature=feature))
#print("Writing TFRecords")
#for shard_id, (image, label, height, width) in ds2.enumerate():
#    shard_size = image.numpy().shape[0]
#    filename = GCS_OUTPUT + "{:02d}-{}.tfrec".format(shard_id, shard_size)
#
#    with tf.io.TFRecordWriter(filename) as out_file:
#        for i in range(shard_size):
#            example = to_tfrecord(out_file,
#                                  image.numpy()[i],
#                                  label.numpy()[i],
#                                  height.numpy()[i],
#                                  width.numpy()[i])
#            out_file.write(example.SerializeToString())
#    print("Wrote file {} containing {} records".format(filename, shard_size))
def read_tfrecord(example):
    """Parse one serialized flowers tf.train.Example.

    Returns:
        (image, class_num, label, height, width, one_hot_class):
        - image: uint8 tensor reshaped to [*TARGET_SIZE, 3];
        - class_num: int64 scalar;
        - label: bytestring;
        - height, width: int64 scalars;
        - one_hot_class: dense float32 vector.
    """
    schema = {
        # tf.string is a raw bytestring here, not a text string.
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),     # shape [] == scalar
        # Extra fields that demonstrate reading other feature types.
        "label": tf.io.FixedLenFeature([], tf.string),    # one bytestring
        "size": tf.io.FixedLenFeature([2], tf.int64),     # two integers
        "one_hot_class": tf.io.VarLenFeature(tf.float32), # variable-length floats
    }
    parsed = tf.io.parse_single_example(example, schema)

    image = tf.image.decode_jpeg(parsed['image'], channels=3)
    image = tf.reshape(image, [*TARGET_SIZE, 3])
    height, width = parsed['size'][0], parsed['size'][1]
    # VarLenFeature values come back as a SparseTensor; densify them.
    one_hot_class = tf.sparse.to_dense(parsed['one_hot_class'])
    return image, parsed['class'], parsed['label'], height, width, one_hot_class
# Let tf.data yield elements out of order when that is faster.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

filenames = tf.io.gfile.glob(GCS_OUTPUT + "*tfrec")
ds3 = (tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
       .with_options(option_no_order)
       .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
       .shuffle(30))

# Keep only (image, text label) for plotting.
ds3_to_show = ds3.map(
    lambda image, class_num, label, height, width, one_hot: (image, label))
show_images(ds3_to_show)
%%time
# Same 10 x 8 throughput probe, now reading from TFRecord shards.
for image, class_num, label, height, width, one_hot_class in ds3.batch(8).take(10):
    names = [lbl.decode('utf8') for lbl in label.numpy()]
    print("Image batch shape {} {}".format(image.numpy().shape, names))
# Training hyperparameters. The trailing `#@param` comments are Colab form
# annotations and must stay on the same line as their assignment.
# NOTE(review): BASE_MODEL is only reported by the print below; the model
# itself is built from EfficientNetB3 directly.
BASE_MODEL = 'efficientnet_b3' #@param ["'efficientnet_b3'", "'efficientnet_b4'", "'efficientnet_b2'"] {type:"raw", allow-input: true}
# Model input size: (HEIGHT, WIDTH, CHANNELS).
HEIGHT = 300#@param {type:"number"}
WIDTH = 300#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)
EPOCHS = 50#@param {type:"number"}
# Global batch size: 32 per replica times the replica count of `strategy`,
# which must already be defined when this runs.
BATCH_SIZE = 32 * strategy.num_replicas_in_sync #@param {type:"raw"}
print("Use {} with input size {}".format(BASE_MODEL, IMG_SIZE))
print("Train on batch size of {} for {} epochs".format(BATCH_SIZE, EPOCHS))
%%run_if {KAGGLE}
from kaggle_datasets import KaggleDatasets

# Kaggle: resolve the competition's GCS bucket and select the 512x512 TFRecords.
GCS_PATH = '{}/tfrecords-jpeg-512x512'.format(KaggleDatasets().get_gcs_path(project_name))
print(f"Sourcing images from {GCS_PATH}")
%%run_if {GOOGLE}
#@title {run: "auto", display-mode: "form"}
# Colab: KaggleDatasets is unavailable, so the bucket is hard-coded (presumably
# the path get_gcs_path returns for this competition on Kaggle; refresh it if
# it expires). The `#@param` markers turn both lines into Colab form fields.
GCS_PATH = 'gs://kds-c6b9829baa483a13a169c7cbe266341fb8c9b5ba36843af37a093a4c' #@param {type: "string"}
GCS_PATH += '/tfrecords-jpeg-512x512' #@param {type: "string"}
print(f"Sourcing images from {GCS_PATH}")
# The 104 flower classes of the competition; list index == integer class id
# stored in the TFRecords' "class" feature.
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry trailing spaces, presumably
# copied verbatim from the competition label list; confirm before stripping.
CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium',
'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle',
'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris',
'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower','giant white arum lily',
'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth',
'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william',
'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly',
'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose',
'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion',
'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus',
'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia',
'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy',
'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy',
'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily',
'hippeastrum ', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea',
'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower',
'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose']
# Number of output units the classifier needs.
NCLASSES = len(CLASSES)
print(f"This dataset has {NCLASSES} labels!")
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image randomly cropped to IMG_SIZE.

    On Colab (GOOGLE) pixel values stay in [0, 255]; elsewhere they are scaled
    to [0, 1]. Presumably this is because tf.keras.applications' EfficientNet,
    used on Colab, rescales its inputs internally; confirm.
    NOTE(review): the crop is random for every caller, including the
    validation and test pipelines.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image_data, channels=CHANNELS), tf.float32)
    if not GOOGLE:
        pixels = pixels / 255.0
    return tf.image.random_crop(pixels, IMG_SIZE)
def collate_labeled_tfrecord(example):
    """Parse a labeled TFRecord example into (image, int32 class id)."""
    parsed = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),   # JPEG bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),    # scalar class id
    })
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)
def process_unlabeled_tfrecord(example):
    """Parse an unlabeled (test) TFRecord example into (image, id bytestring)."""
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),   # JPEG bytestring
        "id": tf.io.FixedLenFeature([], tf.string),      # scalar sample id
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), parsed['id']
# Record count embedded in a shard name, e.g. "00-512x512-798.tfrec" -> 798.
_RECORD_COUNT = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of records encoded in TFRecord file names.

    Only the base name of each path is searched, so dashes and dots elsewhere
    in a bucket path cannot be mistaken for a record count.

    Args:
        filenames: iterable of TFRecord paths (local or ``gs://``).

    Returns:
        Summed record count as a Python int (0 when there are no files).

    Raises:
        ValueError: if a file name carries no ``-<count>.`` record count.
    """
    total = 0
    for filename in filenames:
        basename = filename.rsplit('/', 1)[-1]
        match = _RECORD_COUNT.search(basename)
        if match is None:
            raise ValueError(f"no record count in TFRecord file name: {filename!r}")
        total += int(match.group(1))
    return total
# Shard lists for each split of the competition data.
train_filenames, valid_filenames, test_filenames = (
    tf.io.gfile.glob(f'{GCS_PATH}/{split}/*.tfrec') for split in ('train', 'val', 'test'))
split_sizes = [count_data_items(files)
               for files in (train_filenames, valid_filenames, test_filenames)]
print("Number of train set: {}\n"
      "Number of valid set: {}\n"
      "Number of test set: {}\n".format(*split_sizes))
# data augmentation @cdeotte kernel:
# https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
#
# The three transforms share one resampling routine: build a 3x3 matrix, map
# every destination pixel back onto the source grid, and gather.


def _affine_matrix(entries):
    """Reshape nine length-1 float32 tensors (row-major) into a 3x3 matrix."""
    return tf.reshape(tf.concat(entries, axis=0), [3, 3])


def _remap_square_image(image, DIM, matrix):
    """Resample one square image through a 3x3 coordinate ``matrix``.

    Args:
        image: a single [DIM, DIM, 3] image (not a batch).
        DIM: side length of ``image`` in pixels.
        matrix: float32 [3, 3] matrix applied to the centred (row, col, 1)
            coordinates of every destination pixel.

    Returns:
        The resampled [DIM, DIM, 3] image. Source coordinates are truncated to
        int (no interpolation) and clipped, so out-of-range lookups repeat the
        border pixels.
    """
    XDIM = DIM % 2  # fix for odd sizes such as 331
    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat(tf.range(DIM // 2, -DIM // 2, -1), DIM)
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])
    z = tf.ones([DIM * DIM], dtype='int32')
    idx = tf.stack([x, y, z])
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = tf.matmul(matrix, tf.cast(idx, dtype='float32'))
    idx2 = tf.cast(idx2, dtype='int32')
    idx2 = tf.clip_by_value(idx2, -DIM // 2 + XDIM + 1, DIM // 2)
    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack([DIM // 2 - idx2[0, ], DIM // 2 - 1 + idx2[1, ]])
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d, [DIM, DIM, 3])


def transform_rotation(image, height, rotation):
    """Rotate a square image by a random angle in [0, ``rotation``) degrees.

    image: one [dim, dim, 3] image, not a batch; height: its side length.
    """
    rotation = rotation * tf.random.uniform([1], dtype='float32')
    rotation = math.pi * rotation / 180.  # degrees -> radians
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = _affine_matrix([c1, s1, zero,
                                      -s1, c1, zero,
                                      zero, zero, one])
    return _remap_square_image(image, height, rotation_matrix)


def transform_shear(image, height, shear):
    """Shear a square image by a random angle in [0, ``shear``) degrees.

    image: one [dim, dim, 3] image, not a batch; height: its side length.
    """
    shear = shear * tf.random.uniform([1], dtype='float32')
    shear = math.pi * shear / 180.  # degrees -> radians
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = _affine_matrix([one, s2, zero,
                                   zero, c2, zero,
                                   zero, zero, one])
    return _remap_square_image(image, height, shear_matrix)


def transform_shift(image, height, h_shift, w_shift):
    """Translate a square image by random offsets.

    The offsets lie in [0, ``h_shift``) and [0, ``w_shift``) pixels.
    image: one [dim, dim, 3] image, not a batch; height: its side length.
    """
    height_shift = h_shift * tf.random.uniform([1], dtype='float32')
    width_shift = w_shift * tf.random.uniform([1], dtype='float32')
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    shift_matrix = _affine_matrix([one, zero, height_shift,
                                   zero, one, width_shift,
                                   zero, zero, one])
    return _remap_square_image(image, height, shift_matrix)
def data_augment(image, label):
    """Randomly augment one training image; ``label`` passes through unchanged.

    Each family of augmentations is gated by its own uniform draw:
      * flips (80%): random left-right and up-down flips;
      * 90-degree rotations (75%): by 90, 180 or 270 degrees;
      * affine (70% each): random rotation, shift and shear;
      * crops (60%): a central crop keeping 70/80/90% (30% of images) or a
        random square crop of 70-100% of HEIGHT (another 30%);
      * pixel-level (80%): one of saturation, contrast, brightness or gamma.
    The image is always resized back to [HEIGHT, WIDTH] after the crop step.
    Python ``if``s on tensors are converted to graph conditionals by autograph
    inside ``Dataset.map``.
    NOTE(review): the brightness delta and gamma look tuned for [0, 1] pixels;
    on Colab decode_image keeps pixels in [0, 255].
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    if p_spatial >= .2:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Rotations by multiples of 90 degrees
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)  # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)  # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)  # rotate 90º

    # Affine transforms
    if p_rotation >= .3:  # Rotation
        image = transform_rotation(image, height=HEIGHT, rotation=45.)
    if p_shift >= .3:  # Shift
        image = transform_shift(image, height=HEIGHT, h_shift=15., w_shift=15.)
    if p_shear >= .3:  # Shear
        image = transform_shear(image, height=HEIGHT, shear=20.)

    # Crops. The narrower `> .7` test must come first: if `> .4` were checked
    # first, every p_crop > .7 would take the random-crop branch and the
    # central crops could never run.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.7), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
    # Crops change the size, so always return to the model input size.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    # Pixel-level transforms
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=0, upper=2)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.2)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)

    return image, label
`experimental_deterministic` controls whether the pipeline must produce its outputs in a deterministic order (default: `True`). Setting it to `False` lets tf.data return elements out of order when that is faster.
# Options that allow out-of-order element delivery for extra throughput.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

# Training pipeline: parse -> augment -> repeat forever -> shuffle -> batch.
train_ds = (tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
            .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
            .map(data_augment, num_parallel_calls=AUTOTUNE)
            .repeat()
            .shuffle(2048)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))
show_images(train_ds.take(1).unbatch())
# Validation pipeline: parse, batch, and cache the batches in memory after the
# first full pass.
valid_ds = (tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
            .with_options(option_no_order)
            .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .cache()
            .prefetch(AUTOTUNE))
show_images(valid_ds.take(1).unbatch())
# Test pipeline: unlabeled examples come out as batches of (image, id).
test_ds = (tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
           .with_options(option_no_order)
           .map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(AUTOTUNE))
Augmentation can be applied in two ways:
- Using the Keras preprocessing layers
- Using the `tf.image` module

Important: the Keras preprocessing layers are currently experimental, and they do not seem to have TPU OpKernel support yet.
#batch_augment = tf.keras.Sequential(
# [
# tf.keras.layers.experimental.preprocessing.RandomCrop(*IMG_SIZE),
# tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
# tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
# tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
# tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
# ]
#)
#func = lambda x,y: (batch_augment(x), y)
#x = (train_ds
# .take(1)
# .map(func, num_parallel_calls=AUTOTUNE))
Now we're ready to create a neural network for classifying images! We'll use what's known as transfer learning. With transfer learning, you reuse the body of a pretrained model and replace its head (or tail) with custom layers suited to the problem you are solving.
For this tutorial, we'll use EfficientNetB3, which is pretrained on ImageNet. Later, I might want to experiment with other models. (Xception wouldn't be a bad choice.)
We build the model inside `strategy.scope()`. This context manager tells TensorFlow how to divide the work of training among the eight TPU cores. When using TensorFlow with a TPU, it's important to define your model in a `strategy.scope()` context.
%%run_if {KAGGLE}
!pip install -q efficientnet
from efficientnet.tfkeras import EfficientNetB3
%%run_if {GOOGLE}
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications import VGG16
def build_model(base_model, num_class):
    """Put a dropout + softmax classification head on top of ``base_model``.

    Args:
        base_model: Keras feature extractor that accepts IMG_SIZE inputs.
        num_class: number of output classes.

    Returns:
        An uncompiled tf.keras Model mapping images to class probabilities.
    """
    layers = tf.keras.layers
    inputs = layers.Input(shape=IMG_SIZE)
    features = layers.Dropout(0.4)(base_model(inputs))
    outputs = layers.Dense(num_class, activation="softmax", name="pred")(features)
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)
# Create the model's variables under the distribution strategy so they are
# placed according to `strategy`.
with strategy.scope():
    efficientnet = EfficientNetB3(
        weights='imagenet' if TRAIN else None,
        include_top=False,
        input_shape=IMG_SIZE,
        pooling='avg',
    )
    efficientnet.trainable = True  # fine-tune the whole backbone
    model = build_model(base_model=efficientnet, num_class=len(CLASSES))
I used the `CosineDecayRestarts` learning-rate schedule implemented in tf.keras, as it seemed promising and I struggled to find the right settings (if there were any) for `ReduceLROnPlateau`.
# Total optimizer steps for the whole run: ceil(train images / global batch)
# steps per epoch, times EPOCHS.
STEPS = math.ceil(count_data_items(train_filenames) / BATCH_SIZE) * EPOCHS
LR_START = 1e-4 #@param {type: "number"}
LR_START *= strategy.num_replicas_in_sync  # scale with the replica count
LR_MIN = 1e-5 #@param {type: "number"}
N_RESTARTS = 5#@param {type: "number"}
T_MUL = 2.0 #@param {type: "number"}
M_MUL = 1#@param {type: "number"}

# First cycle length chosen so that N_RESTARTS + 1 cycles, each T_MUL times
# longer than the previous one, sum to STEPS:
#   STEPS = first * (T_MUL**(N+1) - 1) / (T_MUL - 1), or first * (N+1) if T_MUL == 1.
if T_MUL == 1:
    STEPS_START = math.ceil(STEPS / (N_RESTARTS + 1))
else:
    STEPS_START = math.ceil((T_MUL-1)/(T_MUL**(N_RESTARTS+1)-1) * STEPS)

schedule = tf.keras.experimental.CosineDecayRestarts(
    first_decay_steps=STEPS_START,
    initial_learning_rate=LR_START,
    # `alpha` is the LR floor as a *fraction* of initial_learning_rate, so the
    # absolute LR_MIN is converted (passing LR_MIN itself would put the floor
    # at LR_MIN * LR_START).
    alpha=LR_MIN / LR_START,
    m_mul=M_MUL,
    t_mul=T_MUL)

# Plot the learning rate over the whole run.
x = [i for i in range(STEPS)]
y = [schedule(s) for s in range(STEPS)]
_,ax = plt.subplots(1,1,figsize=(8,5),facecolor='#F0F0F0')
ax.plot(x, y)
ax.set_facecolor('#F8F8F8')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')
print('{:d} total epochs and {:d} steps per epoch'
      .format(EPOCHS, STEPS // EPOCHS))
print(schedule.get_config())
# Keep the lowest-val_loss model on disk; stop once val_loss has not improved
# for 10 epochs and restore the best weights in memory.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath='001_best_model.h5',
                                       monitor='val_loss',
                                       save_best_only=True),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                     mode='min',
                                     patience=10,
                                     restore_best_weights=True,
                                     verbose=1),
]
# Integer class labels -> sparse categorical loss/metric; Adam follows the LR schedule.
model.compile(optimizer=tf.keras.optimizers.Adam(schedule),
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
model.summary()
%%run_if {GOOGLE}
def generate_unlabeled_tfrecord(example):
    """Return only the image of a labeled TFRecord example, divided by 255.

    The class id is parsed but discarded; the result feeds the Normalization
    layer's ``adapt`` call.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),   # JPEG bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),    # scalar class id
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']) / 255.0
%%run_if {GOOGLE}
if not os.path.exists("000_normalization.h5"):
    # First run: fit the backbone's Normalization layer statistics on the
    # training images, then save the weights so later runs can skip this.
    adapt_ds = (tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
                .map(generate_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
                .shuffle(2048)
                .batch(BATCH_SIZE)
                .prefetch(AUTOTUNE))
    model.get_layer('efficientnetb3').get_layer('normalization').adapt(adapt_ds)
    model.save_weights("000_normalization.h5")
else:
    model.load_weights("000_normalization.h5")
# STEPS is ceil(n_train / BATCH_SIZE) * EPOCHS, i.e. the step count for the
# whole run, so one epoch is STEPS // EPOCHS steps. Dividing STEPS by
# BATCH_SIZE again would shrink each epoch to a few batches and use only a
# small part of the learning-rate schedule.
history = model.fit(
    x=train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS // EPOCHS,
    callbacks=callbacks,
    verbose=2
)
# Reload the best checkpoint written by the ModelCheckpoint callback.
model.load_weights("001_best_model.h5")
#import pickle
#with open('trainHistoryDict', 'wb') as fd:
# pickle.dump(history.history, fd)
#history = pickle.load(open('trainHistoryDict', "rb"))
def show_history(history):
topics = ['loss', 'accuracy']
groups = [{k:v for (k,v) in history.items() if topic in k} for topic in topics]
_,axs = plt.subplots(1,2,figsize=(15,6),facecolor='#F0F0F0')
for topic,group,ax in zip(topics,groups,axs.flatten()):
for (_,v) in group.items(): ax.plot(v)
ax.set_facecolor('#F8F8F8')
ax.set_title(f'{topic} over epochs')
ax.set_xlabel('epoch')
ax.set_ylabel(topic)
ax.legend(['train', 'valid'], loc='best')
# Plot the training curves from the History object returned by model.fit above.
show_history(history)
def show_confusion_matrix(cmat, score, precision, recall):
_,ax = plt.subplots(1,1,figsize=(12,12),facecolor='#F0F0F0')
ax.matshow(cmat, cmap='Blues')
ax.set_xticks(range(len(CLASSES)),)
ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
ax.set_yticks(range(len(CLASSES)))
ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
textstr = ""
if precision: textstr += 'precision = {:.3f} '.format(precision)
if recall: textstr += '\nrecall = {:.3f} '.format(recall)
if score: textstr += '\nf1 = {:.3f} '.format(score)
if len(textstr) > 0:
props = dict(boxstyle='round', facecolor='wheat', alpha=0.2)
ax.text(0.75, 0.95, textstr, transform=ax.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
plt.show()
# Validation pipeline without shuffling or the no-order option, so the label
# pass and the prediction pass below see examples in the same order
# (tf.data's default deterministic ordering).
ordered_valid_ds = (tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
                    .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
                    .batch(BATCH_SIZE)
                    .cache()
                    .prefetch(AUTOTUNE))
x_valid_ds = ordered_valid_ds.map(lambda x, y: x, num_parallel_calls=AUTOTUNE)
y_valid_ds = ordered_valid_ds.map(lambda x, y: y, num_parallel_calls=AUTOTUNE)
# Every validation label as one NumPy array (a single batch of all items).
y_true = (y_valid_ds
          .unbatch()
          .batch(count_data_items(valid_filenames))
          .as_numpy_iterator()
          .next())
y_probs = model.predict(x_valid_ds)
y_preds = np.argmax(y_probs, axis=-1)
label_ids = range(len(CLASSES))
cmatrix = confusion_matrix(y_true, y_preds, labels=label_ids)
# Row-normalise so each row shows how one true class was predicted. Rows of
# classes absent from y_true stay 0 instead of becoming NaN (0 / 0).
row_totals = cmatrix.sum(axis=1, keepdims=True)
cmatrix = np.divide(cmatrix, row_totals,
                    out=np.zeros(cmatrix.shape, dtype=float),
                    where=row_totals > 0)
You might be familiar with metrics like F1-score or precision and recall. This cell will compute these metrics and display them with a plot of the confusion matrix. (These metrics are defined in the scikit-learn module `sklearn.metrics`; we imported them at the top of the notebook.)
# Macro-averaged precision, recall and F1 over all classes.
precision, recall, score = (
    metric(y_true, y_preds, labels=label_ids, average='macro')
    for metric in (precision_score, recall_score, f1_score))
show_confusion_matrix(cmatrix, score, precision, recall)
Visual Validation
It can also be helpful to look at some examples from the validation set and see what class your model predicted. This can help reveal patterns in the kinds of images your model has trouble with. This cell will set up the validation set to display 20 images at a time -- you can change this to display more or fewer, if you like.
1% Better Every Day
reference
- Create Your First Submission
- How to use my own data source?
- TPU-speed data pipelines: tf.data.Dataset and TFRecords
todos
- Comment out the 1/255.0 in the image preprocessing
- Reorganize the notebook structure
done